{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Move your annotations to data/custom/labels/. \n",
    "2. The dataloader expects that the annotation file corresponding to the image data/custom/images/train.jpg has the path data/custom/labels/train.txt. \n",
    "2.1. Each row in the annotation file should define one bounding box, using the syntax label_idx x_center y_center width height. \n",
    "The coordinates should be scaled [0, 1], \n",
    "and the label_idx should be zero-indexed and correspond to the row number of the class name in data/custom/classes.names."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Annotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['一次性快餐盒', '书籍纸张', '充电宝', '剩饭剩菜', '包', '垃圾桶', '塑料器皿', '塑料玩具', '塑料衣架', '大骨头', '干电池', '快递纸袋', '插头电线', '旧衣服', '易拉罐', '枕头', '果皮果肉', '毛绒玩具', '污损塑料', '污损用纸', '洗护用品', '烟蒂', '牙签', '玻璃器皿', '砧板', '筷子', '纸盒纸箱', '花盆', '茶叶渣', '菜帮菜叶', '蛋壳', '调料瓶', '软膏', '过期药物', '酒瓶', '金属厨具', '金属器皿', '金属食品罐', '锅', '陶瓷器皿', '鞋', '食用油桶', '饮料瓶', '鱼骨']\n"
     ]
    }
   ],
   "source": [
    "def xyxy2xyhw(bbox, width, height):\n",
    "    # 归一化\n",
    "    flag = True\n",
    "    bbox[0] = bbox[0]/width\n",
    "    bbox[1] = bbox[1]/height\n",
    "    bbox[2] = bbox[2]/width\n",
    "    \n",
    "    bbox[3] = bbox[3]/height\n",
    "    \n",
    "    if(bbox[2]>1 or bbox[3]>1):\n",
    "        flag = False\n",
    "    \n",
    "    ##\n",
    "    w = bbox[2] - bbox[0]\n",
    "    h = bbox[3] - bbox[1]\n",
    "    x_center = bbox[0] + w / 2\n",
    "    y_center = bbox[1] + h / 2\n",
    "    return [x_center, y_center, w, h], flag\n",
    "\n",
    "source_root = '/store/dataset/rubbish/VOC2007/'\n",
    "des_root = '/store/dataset/rubbish_yolo/labels'\n",
    "import os\n",
    "imagesetPath = os.path.join(source_root, 'ImageSets', 'Main', 'trainval.txt')\n",
    "\n",
    "import numpy as np\n",
    "imagenames = np.loadtxt(imagesetPath, np.str)\n",
    "\n",
    "from xml.etree import ElementTree as ET\n",
    "AnnotationRoot = os.path.join(source_root, 'Annotations')\n",
    "\n",
    "classnames = list(np.loadtxt(\"classes.names\", np.str))\n",
    "print(classnames)\n",
    "\n",
    "for imagename in imagenames:\n",
    "    desPath = os.path.join(des_root, imagename+'.txt')\n",
    "    AnnotationPath = os.path.join(AnnotationRoot, imagename+'.xml')\n",
    "    tree = ET.parse(AnnotationPath)\n",
    "    height = float(tree.findall(\"./size/height\")[0].text)\n",
    "    width = float(tree.findall('./size/width')[0].text)\n",
    "    #print(height, width)\n",
    "    \n",
    "    with open(desPath, 'w') as f:\n",
    "        for obj in tree.findall(\"object\"):\n",
    "            cls = obj.find(\"name\").text\n",
    "            cls_ind = classnames.index(cls)\n",
    "            #print(\"cls_name:%s, cls_ind:%d\"%(cls, cls_ind))\n",
    "            bbox = obj.find(\"bndbox\")\n",
    "            bbox = [float(bbox.find(x).text) for x in [\"xmin\", \"ymin\",\"xmax\",\"ymax\"]]\n",
    "            ##xyxy=>xyhw\n",
    "            bbox, flag = xyxy2xyhw(bbox,width, height)\n",
    "            if flag==False:\n",
    "                print(desPath)\n",
    "            saved = np.concatenate((np.array(cls_ind)[np.newaxis], np.array(bbox)), 0)\n",
    "            #print(saved)\n",
    "            saved = saved[:,np.newaxis].transpose()\n",
    "            np.savetxt(f, saved, fmt=\"%.1f %.6f %.6f %.6f %.6f\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[False, False, False],\n",
       "         [False, False, False]]], device='cuda:0')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "a = torch.cuda.BoolTensor(1,2,3).fill_(0)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "44\n",
      "['一次性快餐盒', '书籍纸张', '充电宝', '剩饭剩菜', '包', '垃圾桶', '塑料器皿', '塑料玩具', '塑料衣架', '大骨头', '干电池', '快递纸袋', '插头电线', '旧衣服', '易拉罐', '枕头', '果皮果肉', '毛绒玩具', '污损塑料', '污损用纸', '洗护用品', '烟蒂', '牙签', '玻璃器皿', '砧板', '筷子', '纸盒纸箱', '花盆', '茶叶渣', '菜帮菜叶', '蛋壳', '调料瓶', '软膏', '过期药物', '酒瓶', '金属厨具', '金属器皿', '金属食品罐', '锅', '陶瓷器皿', '鞋', '食用油桶', '饮料瓶', '鱼骨']\n"
     ]
    }
   ],
   "source": [
    "def load_classes(path):\n",
    "    \"\"\"\n",
    "    Loads class labels at 'path'\n",
    "    \"\"\"\n",
    "    fp = open(path, \"r\")\n",
    "    names = fp.read().strip().split(\"\\n\")\n",
    "    return names\n",
    "\n",
    "path = \"/home/lrh/program/git/object_detection/code/PyTorch-YOLOv3/data/coco.names\"\n",
    "\n",
    "path = \"/store/dataset/rubbish_yolo/classes.names\"\n",
    "classes = load_classes(path)\n",
    "print(len(classes))\n",
    "print(classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## train val split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_root = '/store/dataset/rubbish/VOC2007/'\n",
    "des_root = '/store/dataset/rubbish_yolo/'\n",
    "import os\n",
    "trainimagesetPath = os.path.join(source_root, 'ImageSets', 'Main', 'train.txt')\n",
    "testimagesetPath = os.path.join(source_root, 'ImageSets', 'Main', 'train.txt')\n",
    "\n",
    "trainfilepaths = [os.path.join(des_root, 'images', filename+'.jpg') for filename in list(np.loadtxt(trainimagesetPath, np.str))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "traintxt = os.path.join(des_root, 'train.txt')\n",
    "print(traintxt)\n",
    "print(len(trainfilepaths))\n",
    "with open(traintxt, 'w') as f:\n",
    "    np.savetxt(f,np.array(trainfilepaths),fmt='%s')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "source_root = '/store/dataset/rubbish/VOC2007/'\n",
    "des_root = '/store/dataset/rubbish_yolo/'\n",
    "import os\n",
    "\n",
    "valimagesetPath = os.path.join(source_root, 'ImageSets', 'Main', 'val.txt')\n",
    "valfilepaths = [os.path.join(des_root, 'images', filename+'.jpg') for filename in list(np.loadtxt(valimagesetPath, np.str))]\n",
    "\n",
    "valtxt = os.path.join(des_root, 'valid.txt')\n",
    "print(valtxt)\n",
    "print(len(valfilepaths))\n",
    "with open(valtxt, 'w') as f:\n",
    "    np.savetxt(f,np.array(valfilepaths),fmt='%s')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data from coco for augmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. 生成 trainval.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "trainvaltxt = '/store/dataset/rubbish/coco2014/ImageSets/Main/trainval.txt'\n",
    "\n",
    "imagenames = [imagename.split('.')[0] for imagename in os.listdir('/store/dataset/rubbish/coco2014/JPEGImages')]\n",
    "\n",
    "with open(trainvaltxt, 'w') as f:\n",
    "    np.savetxt(f, np.array(imagenames), fmt='%s')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. 转化成 训练的 格式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['一次性快餐盒', '书籍纸张', '充电宝', '剩饭剩菜', '包', '垃圾桶', '塑料器皿', '塑料玩具', '塑料衣架', '大骨头', '干电池', '快递纸袋', '插头电线', '旧衣服', '易拉罐', '枕头', '果皮果肉', '毛绒玩具', '污损塑料', '污损用纸', '洗护用品', '烟蒂', '牙签', '玻璃器皿', '砧板', '筷子', '纸盒纸箱', '花盆', '茶叶渣', '菜帮菜叶', '蛋壳', '调料瓶', '软膏', '过期药物', '酒瓶', '金属厨具', '金属器皿', '金属食品罐', '锅', '陶瓷器皿', '鞋', '食用油桶', '饮料瓶', '鱼骨']\n"
     ]
    }
   ],
   "source": [
    "def xyxy2xyhw(bbox, width, height):\n",
    "    # 归一化\n",
    "    flag = True\n",
    "    bbox[0] = bbox[0]/width\n",
    "    bbox[1] = bbox[1]/height\n",
    "    bbox[2] = bbox[2]/width\n",
    "    \n",
    "    bbox[3] = bbox[3]/height\n",
    "    \n",
    "    if(bbox[2]>1 or bbox[3]>1):\n",
    "        flag = False\n",
    "    \n",
    "    ##\n",
    "    w = bbox[2] - bbox[0]\n",
    "    h = bbox[3] - bbox[1]\n",
    "    x_center = bbox[0] + w / 2\n",
    "    y_center = bbox[1] + h / 2\n",
    "    return [x_center, y_center, w, h], flag\n",
    "\n",
    "source_root = '/store/dataset/rubbish/coco2014/'\n",
    "des_root = '/store/dataset/rubbish_yolo/coco/labels'\n",
    "import os\n",
    "imagesetPath = os.path.join(source_root, 'ImageSets', 'Main', 'trainval.txt')\n",
    "\n",
    "##复制images文件夹 和 trainval.txt\n",
    "import shutil\n",
    "shutil.copytree(os.path.join(source_root, 'JPEGImages'), os.path.join(des_root, '../images'))\n",
    "shutil.copy(imagesetPath, os.path.join(des_root, '../trainval.txt'))\n",
    "\n",
    "import numpy as np\n",
    "imagenames = np.loadtxt(imagesetPath, np.str)\n",
    "\n",
    "from xml.etree import ElementTree as ET\n",
    "AnnotationRoot = os.path.join(source_root, 'Annotations')\n",
    "\n",
    "classnames = list(np.loadtxt(\"/store/dataset/rubbish/train_classes.txt\", np.str))\n",
    "print(classnames)\n",
    "\n",
    "for imagename in imagenames:\n",
    "    desPath = os.path.join(des_root, imagename+'.txt')\n",
    "    AnnotationPath = os.path.join(AnnotationRoot, imagename+'.xml')\n",
    "    tree = ET.parse(AnnotationPath)\n",
    "    height = float(tree.findall(\"./size/height\")[0].text)\n",
    "    width = float(tree.findall('./size/width')[0].text)\n",
    "    #print(height, width)\n",
    "    \n",
    "    with open(desPath, 'w') as f:\n",
    "        for obj in tree.findall(\"object\"):\n",
    "            cls = obj.find(\"name\").text\n",
    "            if cls == '书籍':\n",
    "                cls = '书籍纸张'\n",
    "            cls_ind = classnames.index(cls)\n",
    "            #print(\"cls_name:%s, cls_ind:%d\"%(cls, cls_ind))\n",
    "            bbox = obj.find(\"bndbox\")\n",
    "            bbox = [float(bbox.find(x).text) for x in [\"xmin\", \"ymin\",\"xmax\",\"ymax\"]]\n",
    "            ##xyxy=>xyhw\n",
    "            bbox, flag = xyxy2xyhw(bbox,width, height)\n",
    "            if flag==False:\n",
    "                print(desPath)\n",
    "            saved = np.concatenate((np.array(cls_ind)[np.newaxis], np.array(bbox)), 0)\n",
    "            #print(saved)\n",
    "            saved = saved[:,np.newaxis].transpose()\n",
    "            np.savetxt(f, saved, fmt=\"%.1f %.6f %.6f %.6f %.6f\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 将数据合并生成三个新的训练数据集 trainval.txt, train+coco1000.txt, trainval+coco1000.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. trainval.txt 合并 train.txt 和 val.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14964\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "txtroot = '/store/dataset/rubbish_yolo'\n",
    "def combine2txt(filename1, filename2):\n",
    "    traintxt = os.path.join(txtroot, filen)\n",
    "    validtxt = os.path.join(txtroot, 'valid.txt')\n",
    "\n",
    "    with open(traintxt, 'r') as f:\n",
    "        trainpaths = np.loadtxt(f, np.str)\n",
    "\n",
    "    with open(validtxt, 'r') as f:\n",
    "        validpaths = np.loadtxt(f, np.str)\n",
    "\n",
    "    trainvalpaths = []\n",
    "\n",
    "    trainvalpaths.extend(trainpaths)\n",
    "    trainvalpaths.extend(validpaths)\n",
    "\n",
    "    print(len(trainvalpaths))\n",
    "\n",
    "    with open(os.path.join(txtroot, 'trainval.txt'), 'w') as f:\n",
    "        np.savetxt(f, np.array(trainvalpaths), fmt='%s')\n",
    "combine2txt('train.txt', 'valid.txt', )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 生成coco1000.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'1': 1000, '35': 1000, '27': 1000}\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "txtroot = '/store/dataset/rubbish_yolo'\n",
    "traintxt = os.path.join(txtroot, 'train.txt')\n",
    "cocotxt = os.path.join(txtroot, 'coco/trainval.txt')\n",
    "\n",
    "with open(cocotxt, 'r') as f:\n",
    "    imagepaths = [os.path.join(os.path.dirname(cocotxt), 'images', imagename+'.jpg') \n",
    "                  for imagename in np.loadtxt(f, np.str)]\n",
    "\n",
    "labelpaths = [imagepath.replace('images', 'labels').replace('.jpg', '.txt') for imagepath in imagepaths]\n",
    "\n",
    "#print(labelpaths)\n",
    "num_c = dict()\n",
    "num_c['1'] = 0\n",
    "num_c['35'] = 0\n",
    "num_c['27'] = 0\n",
    "\n",
    "\n",
    "cocoimagepaths = []\n",
    "\n",
    "thres = 1000\n",
    "\n",
    "for ind, labelpath in enumerate(labelpaths):\n",
    "    label = np.loadtxt(labelpath, np.str).astype(np.float)\n",
    "    if label.ndim<2:\n",
    "        label = label[np.newaxis, :]\n",
    "    cls = np.unique(label[:, 0])\n",
    "    for c in cls:\n",
    "        c = str(int(c))\n",
    "        if num_c[c]>=thres:\n",
    "            ## 如果其他类没有到达边界，则也保存该图像\n",
    "            continue;\n",
    "            #break;\n",
    "        else:\n",
    "            ## 将 该图像 包含到 文件中\n",
    "            num_c[c] += 1\n",
    "            cocoimagepaths.append(imagepaths[ind])\n",
    "print(num_c)\n",
    "            \n",
    "with open(os.path.join(txtroot, 'coco1000.txt'), 'w') as f:\n",
    "    np.savetxt(f, np.array(cocoimagepaths), fmt='%s')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
